import numpy as np
import torch


# Scratch experiment: what does integer-tensor ("advanced") indexing return
# when the index tensor has more dimensions than the tensor being indexed?
c = torch.tensor([[2, 2], [3, 3], [4, 4]], dtype=torch.int64)  # shape (3, 2)
d = torch.tensor(
    [[1, 2], [1, 2], [0, 0]], dtype=torch.int64
).unsqueeze(1)  # shape (3, 1, 2)

# Indexing dim 0 of `c` with `d` replaces that dim by d's full shape, so the
# result has shape (3, 1, 2, 2): entry [i, j, k] is the whole row c[d[i, j, k]].
print(c[d])

# a = torch.randn(3,256,256)
# b = torch.argmax(a, dim=0, keepdim=True)

# print(b.shape)
# print(b)
# print(a)



# a = np.random.randn(256,256,3)
# a = np.zeros((256,256,3))
# a[0][0] = 1,1,1
# print(a.shape)
# # print((a==(0,0,0)))
# print(np.all(a==(0,0,0),axis=2))